{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-10-21 07:08:52] Try to use the default NATS-Bench (topology) path from fast_mode=False and path=/Users/xuanyidong/.torch/NATS-tss-v1_0-3ffb9.pickle.pbz2.\n"
     ]
    }
   ],
   "source": [
    "from nats_bench import create\n",
    "from nats_bench.api_utils import time_string\n",
    "import numpy as np\n",
    "\n",
    "# Create the API for size search space\n",
    "api_tss = create(None, \"tss\", fast_mode=False, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------ImageNet16-120--------------------------------------------------\n",
      "Best (10676) architecture on validation: |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_1x1~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
      "Best (857) architecture on       test: |nor_conv_1x1~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
      "using validation ::: validation = 46.73, test = 46.20\n",
      "\n",
      "using test       ::: validation = 46.38, test = 47.31\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def get_valid_test_acc(api, arch, dataset):\n",
    "    is_size_space = api.search_space_name == \"size\"\n",
    "    if dataset == \"cifar10\":\n",
    "        xinfo = api.get_more_info(\n",
    "            arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False\n",
    "        )\n",
    "        test_acc = xinfo[\"test-accuracy\"]\n",
    "        xinfo = api.get_more_info(\n",
    "            arch,\n",
    "            dataset=\"cifar10-valid\",\n",
    "            hp=90 if is_size_space else 200,\n",
    "            is_random=False,\n",
    "        )\n",
    "        valid_acc = xinfo[\"valid-accuracy\"]\n",
    "    else:\n",
    "        xinfo = api.get_more_info(\n",
    "            arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False\n",
    "        )\n",
    "        valid_acc = xinfo[\"valid-accuracy\"]\n",
    "        test_acc = xinfo[\"test-accuracy\"]\n",
    "    return (\n",
    "        valid_acc,\n",
    "        test_acc,\n",
    "        \"validation = {:.2f}, test = {:.2f}\\n\".format(valid_acc, test_acc),\n",
    "    )\n",
    "\n",
    "def find_best_valid(api, dataset):\n",
    "    all_valid_accs, all_test_accs = [], []\n",
    "    for index, arch in enumerate(api):\n",
    "        valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)\n",
    "        all_valid_accs.append((index, valid_acc))\n",
    "        all_test_accs.append((index, test_acc))\n",
    "    best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0]\n",
    "    best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]\n",
    "\n",
    "    print(\"-\" * 50 + \"{:10s}\".format(dataset) + \"-\" * 50)\n",
    "    print(\n",
    "        \"Best ({:}) architecture on validation: {:}\".format(\n",
    "            best_valid_index, api[best_valid_index]\n",
    "        )\n",
    "    )\n",
    "    print(\n",
    "        \"Best ({:}) architecture on       test: {:}\".format(\n",
    "            best_test_index, api[best_test_index]\n",
    "        )\n",
    "    )\n",
    "    _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)\n",
    "    print(\"using validation ::: {:}\".format(perf_str))\n",
    "    _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)\n",
    "    print(\"using test       ::: {:}\".format(perf_str))\n",
    "\n",
    "dataset = \"ImageNet16-120\"\n",
    "find_best_valid(api_tss, dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
